{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp preproc_decorator\n",
    "import os\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\"\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Preprocessing Decorator\n",
    "\n",
    "A decorator to simplify data preprocessing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "from typing import Any, Callable, Iterable, Generator\n",
    "\n",
    "from loguru import logger\n",
    "from fastcore.basics import chunked, listify, partial\n",
    "from fastcore.parallel import num_cpus\n",
    "from joblib import Parallel, delayed\n",
    "import pandas as pd\n",
    "\n",
    "from m3tl.bert_preprocessing.create_bert_features import \\\n",
    "    create_multimodal_bert_features\n",
    "from m3tl.special_tokens import TRAIN, PREDICT\n",
    "from m3tl.utils import (get_or_make_label_encoder,\n",
    "                        load_transformer_tokenizer, set_is_pyspark)\n",
    "from m3tl.params import Params\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Decorator utils functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def has_key_startswith(d: dict, prefix: str) -> bool:\n",
    "    for k in d.keys():\n",
    "        if k.startswith(prefix):\n",
    "            return True\n",
    "    return False\n",
    "\n",
    "\n",
    "def convert_legacy_output(inp: Generator[tuple, None, None]) -> dict:\n",
    "    \"\"\"Convert legacy preproc output to dictionary\n",
    "\n",
    "    Args:\n",
    "        inp (Generator[tuple, None, None]): legacy format output\n",
    "\n",
    "    Returns:\n",
    "        dict: new format output\n",
    "\n",
    "    Yields:\n",
    "        Iterator[dict]\n",
    "    \"\"\"\n",
    "    for record in inp:\n",
    "\n",
    "        if isinstance(record, dict):\n",
    "            yield record\n",
    "        else:\n",
    "            inputs, labels = record\n",
    "\n",
    "            # need to do conversion\n",
    "            if isinstance(inputs, dict) and not has_key_startswith(inputs, 'inputs_'):\n",
    "                new_format_record = {'inputs_{}'.format(\n",
    "                    k): v for k, v in inputs.items()}\n",
    "            elif isinstance(inputs, dict):\n",
    "                new_format_record = inputs\n",
    "            else:\n",
    "                new_format_record = {'inputs_text': inputs}\n",
    "\n",
    "            if isinstance(labels, dict) and not has_key_startswith(labels, 'labels_'):\n",
    "                new_format_record.update({\n",
    "                    'labels_{}'.format(k): v for k, v in labels.items()\n",
    "                })\n",
    "            elif isinstance(labels, dict):\n",
    "                new_format_record.update(labels)\n",
    "            else:\n",
    "                new_format_record['labels'] = labels\n",
    "            yield new_format_record\n",
    "\n",
    "\n",
    "def input_format_check(inp: dict, mode: str):\n",
    "    if not isinstance(inp, dict):\n",
    "        raise ValueError(\n",
    "            \"preproc outout content should be dict, got: {}\".format(type(inp)))\n",
    "\n",
    "    inputs_columns = [k for k in inp.keys() if k.startswith('inputs')]\n",
    "    if not inputs_columns:\n",
    "        raise ValueError(\n",
    "            'inputs should has key with prefix \"inputs\", keys: {}'.format(inp.keys()))\n",
    "\n",
    "    if mode != PREDICT:\n",
    "        labels_columns = [k for k in inp.keys() if k.startswith('labels')]\n",
    "        if not labels_columns:\n",
    "            raise ValueError(\n",
    "                'inputs should has key with prefix \"labels\", keys: {}'.format(inp.keys()))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hide\n",
    "from m3tl.utils import get_phase\n",
    "from m3tl.predefined_problems.test_data import generate_fake_data\n",
    "\n",
    "\n",
    "for input_format in ['gen_dict_tuple', 'dict', 'gen_list_tuple', 'gen_dict', 'dict_tuple']:\n",
    "    fake_data = generate_fake_data(output_format=input_format)\n",
    "    if input_format == 'dict_tuple':\n",
    "        inp, lab = fake_data\n",
    "        inp = pd.DataFrame(inp).to_dict('recrods')\n",
    "        fake_data = zip(inp, lab)\n",
    "    data_iter = convert_legacy_output(fake_data)\n",
    "    # print(next(data_iter))\n",
    "    # # print(fake_data)\n",
    "    if input_format in ['gen_dict_tuple']:\n",
    "        assert list(next(data_iter).keys()) == [\n",
    "            'inputs_text', 'inputs_array', 'inputs_cate', 'inputs_cate_modal_type', 'inputs_cate_modal_info', 'labels']\n",
    "\n",
    "    if input_format == 'gen_dict':\n",
    "        assert list(next(data_iter).keys()) == ['inputs_record_id', 'inputs_text', 'inputs_array',\n",
    "                                                'inputs_cate', 'inputs_cate_modal_type', 'inputs_cate_modal_info', 'labels']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "\n",
    "def none_generator(length: int = None) -> Generator[None, None, None]:\n",
    "    if length is None:\n",
    "        while True:\n",
    "            yield None\n",
    "    else:\n",
    "        for _ in range(length):\n",
    "            yield None\n",
    "\n",
    "\n",
    "def convert_data_to_features(problem: str, data_iter: Iterable, params: Params, label_encoder: Any, tokenizer: Any, mode=TRAIN) -> Iterable[dict]:\n",
    "\n",
    "    if mode != PREDICT:\n",
    "        problem_type = params.problem_type[problem]\n",
    "\n",
    "        # whether this problem is sequential labeling\n",
    "        # for sequential labeling, targets needs to align with any\n",
    "        # change of inputs\n",
    "        is_seq = problem_type in ['seq_tag']\n",
    "    else:\n",
    "        problem_type = 'cls'\n",
    "        is_seq = False\n",
    "\n",
    "    part_fn = partial(create_multimodal_bert_features, problem=problem,\n",
    "                      label_encoder=label_encoder,\n",
    "                      params=params,\n",
    "                      tokenizer=tokenizer,\n",
    "                      mode=mode,\n",
    "                      problem_type=problem_type,\n",
    "                      is_seq=is_seq)\n",
    "    preprocess_buffer = params.preprocess_buffer\n",
    "    data_buffer_list = []\n",
    "    num_cpus = params.num_cpus if params.num_cpus > 0 else num_cpus()\n",
    "    # no easy fix for prediction in multiprocessing\n",
    "    # phase is not shared between processes\n",
    "    num_cpus = 1 if mode == PREDICT else num_cpus\n",
    "    for data_buffer_list in chunked(data_iter, chunk_sz=preprocess_buffer):\n",
    "        per_cpu_chunk = listify(chunked(data_buffer_list, n_chunks=num_cpus))\n",
    "        res_gen = Parallel(num_cpus)(delayed(part_fn)(example_list=d_list)\n",
    "                                     for d_list in per_cpu_chunk)\n",
    "        for d_list in res_gen:\n",
    "            for d in d_list:\n",
    "                yield d\n",
    "\n",
    "\n",
    "def convert_data_to_features_pyspark(\n",
    "        problem: str, dataframe, params: Params, label_encoder: Any, tokenizer: Any, mode=TRAIN):\n",
    "\n",
    "    # whether this problem is sequential labeling\n",
    "    # for sequential labeling, targets needs to align with any\n",
    "    # change of inputs\n",
    "    from copy import deepcopy\n",
    "\n",
    "    params_here = deepcopy(params)\n",
    "    del params_here.read_data_fn\n",
    "\n",
    "    params.num_cpus = 1\n",
    "\n",
    "    dataframe = dataframe.mapPartitions(lambda x: convert_data_to_features(\n",
    "        problem=problem, data_iter=x, params=params_here, tokenizer=tokenizer, label_encoder=label_encoder, mode=mode))\n",
    "\n",
    "    return dataframe\n",
    "\n",
    "def check_if_le_created(problem: str, params: Params):\n",
    "\n",
    "    try:\n",
    "        le_called: bool = params.get_problem_info(problem=problem, info_name='label_encoder_called')\n",
    "        if not le_called:\n",
    "            raise ValueError('If your preprocessing function returns'\n",
    "                    ' a generator or pyspark RDD, you have to call `m3tl.utils.get_or_make_label_encoder` manually. \\n'\n",
    "                    'If you\\'re implementing custom get or make label encoder fn, please specify '\n",
    "                    'num_classes. Example: \\n'\n",
    "                    'params.set_problem_info(problem=problem, info_name=\"num_classes\", info=100)'.format(problem))\n",
    "    except KeyError:\n",
    "        KeyError('If your preprocessing function returns'\n",
    "                    ' a generator or pyspark RDD, you have to call `m3tl.utils.get_or_make_label_encoder` manually. \\n'\n",
    "                    'If you\\'re implementing custom get or make label encoder fn, please specify '\n",
    "                    'num_classes. Example: \\n'\n",
    "                    'params.set_problem_info(problem=problem, info_name=\"num_classes\", info=100)'.format(problem))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Decorator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def preprocessing_fn(func: Callable):\n",
    "    \"\"\"Usually used as a decorator.\n",
    "\n",
    "    The input and output signature of decorated function should be:\n",
    "    func(params: m3tl.Params,\n",
    "         mode: str) -> Union[Generator[X, y], Tuple[List[X], List[y]]]\n",
    "\n",
    "    Where X can be:\n",
    "    - Dicitionary of 'a' and 'b' texts: {'a': 'a test', 'b': 'b test'}\n",
    "    - Text: 'a test'\n",
    "    - Dicitionary of modalities: {'text': 'a test', 'image': np.array([1,2,3])}\n",
    "\n",
    "    Where y can be:\n",
    "    - Text or scalar: 'label_a'\n",
    "    - List of text or scalar: ['label_a', 'label_a1'] (for seq2seq and seq_tag)\n",
    "\n",
    "    This decorator will do the following things:\n",
    "    - load tokenizer\n",
    "    - call func, save as example_list\n",
    "    - create label_encoder and count the number of rows of example_list\n",
    "    - create bert features from example_list and write tfrecord\n",
    "\n",
    "    Args:\n",
    "        func (Callable): preprocessing function for problem\n",
    "    \"\"\"\n",
    "    def wrapper(params, mode, get_data_num=False, write_tfrecord=True):\n",
    "        problem = func.__name__\n",
    "\n",
    "        tokenizer = load_transformer_tokenizer(\n",
    "            params.transformer_tokenizer_name, params.transformer_tokenizer_loading)\n",
    "\n",
    "        # proc func can return one of the following types:\n",
    "        # - Generator\n",
    "        # - Tuple[list] or list\n",
    "        # - pyspark RDD\n",
    "        example_list = func(params, mode)\n",
    "\n",
    "        if isinstance(example_list, tuple) or isinstance(example_list, list):\n",
    "            try:\n",
    "                inputs_list, target_list = example_list\n",
    "            except ValueError:\n",
    "                inputs_list = example_list\n",
    "                target_list = none_generator(len(inputs_list))\n",
    "\n",
    "            if len(inputs_list) == 0:\n",
    "                raise ValueError(\n",
    "                    'problem {} preproc fn returns empty data'.format(problem))\n",
    "\n",
    "            # ugly handling\n",
    "            if isinstance(inputs_list, dict):\n",
    "                inputs_list = pd.DataFrame(inputs_list).to_dict('records')\n",
    "\n",
    "            example_list = zip(inputs_list, target_list)\n",
    "            example_list = convert_legacy_output(example_list)\n",
    "\n",
    "            if mode != PREDICT:\n",
    "                label_encoder = get_or_make_label_encoder(\n",
    "                    params, problem=problem, mode=mode, label_list=target_list)\n",
    "            else:\n",
    "                label_encoder = None\n",
    "\n",
    "            return convert_data_to_features(\n",
    "                problem=problem,\n",
    "                data_iter=example_list,\n",
    "                params=params,\n",
    "                label_encoder=label_encoder,\n",
    "                tokenizer=tokenizer,\n",
    "                mode=mode\n",
    "            )\n",
    "        elif isinstance(example_list, Iterable):\n",
    "            # trigger making label encoder\n",
    "            try:\n",
    "                next(example_list)\n",
    "            except StopIteration:\n",
    "                raise StopIteration(\n",
    "                    'problem {} preproc fn returns empty data'.format(problem))\n",
    "\n",
    "            example_list = func(params, mode)\n",
    "            example_list = convert_legacy_output(example_list)\n",
    "\n",
    "            # create label encoder\n",
    "            if mode != PREDICT:\n",
    "                check_if_le_created(problem, params)\n",
    "                label_encoder = get_or_make_label_encoder(\n",
    "                    params, problem=problem, mode=mode, label_list=[], overwrite=False)\n",
    "            else:\n",
    "                label_encoder = None\n",
    "\n",
    "            return convert_data_to_features(\n",
    "                problem=problem,\n",
    "                data_iter=example_list,\n",
    "                params=params,\n",
    "                label_encoder=label_encoder,\n",
    "                tokenizer=tokenizer,\n",
    "                mode=mode\n",
    "            )\n",
    "        else:\n",
    "            try:\n",
    "                from pyspark import RDD\n",
    "            except ImportError:\n",
    "                raise ImportError(\n",
    "                    \"pyspark is not installed, in this case, preproc \"\n",
    "                    \"function should return a generator, a tuple or a list.\")\n",
    "\n",
    "            if not isinstance(example_list, RDD):\n",
    "                raise ValueError(\"preproc function should return a generator, a tuple, \"\n",
    "                \"a list or a pyspark RDD, got {} from problem {}\".format(\n",
    "                    type(example_list), problem))\n",
    "\n",
    "            set_is_pyspark(True)\n",
    "            if params.pyspark_output_path is None:\n",
    "                raise ValueError(\n",
    "                    \"preproc function of {} returns RDD but \"\n",
    "                    \"params.pyspark_output_path is not set.\".format(problem))\n",
    "                    \n",
    "            if mode != PREDICT:\n",
    "                check_if_le_created(problem, params)\n",
    "                label_encoder = get_or_make_label_encoder(\n",
    "                    params, problem=problem, mode=mode, label_list=[], overwrite=False)\n",
    "            else:\n",
    "                label_encoder = None\n",
    "\n",
    "            return convert_data_to_features_pyspark(\n",
    "                problem=problem,\n",
    "                dataframe=example_list,\n",
    "                params=params,\n",
    "                label_encoder=label_encoder,\n",
    "                tokenizer=tokenizer,\n",
    "                mode=mode\n",
    "            )\n",
    "    return wrapper\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## User-Defined Preprocessing Function\n",
    "\n",
    "The user-defined preprocessing function should return two elements: features and targets, except for `pretrain` problem type.\n",
    "\n",
    "For features and targets, it can be one of the following format:\n",
    "- tuple of list\n",
    "- generator of tuple\n",
    "\n",
    "Please note that if preprocessing function returns generator of tuple, then corresponding problem cannot be chained using `&`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hide\n",
    "import m3tl\n",
    "from m3tl.params import Params\n",
    "from typing import Tuple\n",
    "import shutil\n",
    "import tempfile\n",
    "import numpy as np\n",
    "import os\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-06-22 20:19:16.587 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_fake_ner, problem type: seq_tag\n",
      "2021-06-22 20:19:16.588 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_fake_multi_cls, problem type: multi_cls\n",
      "2021-06-22 20:19:16.588 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_fake_cls, problem type: cls\n",
      "2021-06-22 20:19:16.589 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_masklm, problem type: masklm\n",
      "2021-06-22 20:19:16.589 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_fake_regression, problem type: regression\n",
      "2021-06-22 20:19:16.590 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_fake_vector_fit, problem type: vector_fit\n",
      "2021-06-22 20:19:16.590 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_premask_mlm, problem type: premask_mlm\n",
      "2021-06-22 20:19:16.590 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem fake_contrastive_learning, problem type: contrastive_learning\n",
      "2021-06-22 20:19:16.591 | WARNING  | m3tl.base_params:assign_problem:642 - base_dir and dir_name arguments will be deprecated in the future. Please use model_dir instead.\n",
      "2021-06-22 20:19:16.592 | WARNING  | m3tl.base_params:prepare_dir:361 - bert_config not exists. will load model from huggingface checkpoint.\n",
      "2021-06-22 20:19:22.227 | INFO     | m3tl.utils:set_phase:478 - Setting phase to train\n"
     ]
    }
   ],
   "source": [
    "# hide\n",
    "\n",
    "# setup params for testing\n",
    "from m3tl.test_base import TestBase\n",
    "tb = TestBase()\n",
    "params = tb.params\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "weibo_fake_ner\n",
      "dict_keys(['text_input_ids', 'text_mask', 'text_segment_ids', 'weibo_fake_ner_label_ids'])\n",
      "weibo_fake_multi_cls\n",
      "dict_keys(['text_input_ids', 'text_mask', 'text_segment_ids', 'weibo_fake_multi_cls_label_ids', 'array_input_ids', 'array_mask', 'array_segment_ids', 'cate_input_ids', 'cate_mask', 'cate_segment_ids'])\n",
      "weibo_fake_cls\n",
      "dict_keys(['text_input_ids', 'text_mask', 'text_segment_ids', 'weibo_fake_cls_label_ids'])\n",
      "weibo_masklm\n",
      "dict_keys(['text_input_ids', 'text_mask', 'text_segment_ids', 'masked_lm_positions', 'masked_lm_ids', 'masked_lm_weights'])\n",
      "weibo_fake_regression\n",
      "dict_keys(['record_id', 'text_input_ids', 'text_mask', 'text_segment_ids', 'weibo_fake_regression_label_ids', 'array_input_ids', 'array_mask', 'array_segment_ids', 'cate_input_ids', 'cate_mask', 'cate_segment_ids'])\n",
      "weibo_fake_vector_fit\n",
      "dict_keys(['text_input_ids', 'text_mask', 'text_segment_ids', 'weibo_fake_vector_fit_label_ids'])\n",
      "weibo_premask_mlm\n",
      "dict_keys(['text_input_ids', 'text_mask', 'text_segment_ids', 'weibo_premask_mlm_masked_lm_positions', 'weibo_premask_mlm_masked_lm_ids', 'weibo_premask_mlm_masked_lm_weights'])\n",
      "fake_contrastive_learning\n",
      "dict_keys(['text_input_ids', 'text_mask', 'text_segment_ids', 'fake_contrastive_learning_label_ids'])\n"
     ]
    }
   ],
   "source": [
    "# hide\n",
    "from m3tl.special_tokens import TRAIN\n",
    "\n",
    "for problem_name, fn in params.read_data_fn.items():\n",
    "    print(problem_name)\n",
    "    print(next(fn(params, TRAIN)).keys())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tuple of List\n",
    "\n",
    "#### Single Modal\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@preprocessing_fn\n",
    "def toy_cls(params: Params, mode: str) -> Tuple[list, list]:\n",
    "    \"Simple example to demonstrate singe modal tuple of list return\"\n",
    "    if mode == m3tl.TRAIN:\n",
    "        toy_input = ['this is a toy input' for _ in range(10)]\n",
    "        toy_target = ['a' for _ in range(10)]\n",
    "    else:\n",
    "        toy_input = ['this is a toy input for test' for _ in range(10)]\n",
    "        toy_target = ['a' for _ in range(10)]\n",
    "    return toy_input, toy_target\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-06-22 20:19:23.062 | WARNING  | m3tl.base_params:prepare_dir:361 - bert_config not exists. will load model from huggingface checkpoint.\n"
     ]
    }
   ],
   "source": [
    "# hide\n",
    "import pandas as pd\n",
    "from pyspark import RDD\n",
    "from copy import copy\n",
    "\n",
    "def preproc_dec_test(fn=toy_cls, run_train=True):\n",
    "    params.register_problem(problem_name='toy_cls',\n",
    "                            problem_type='cls', processing_fn=fn)\n",
    "    copy_params = copy(params)\n",
    "    copy_params.assign_problem('toy_cls')\n",
    "    if run_train:\n",
    "        res = fn(params=copy_params, mode=TRAIN)\n",
    "        if isinstance(res, RDD):\n",
    "            print(res.take(1))\n",
    "            return\n",
    "        next(res)\n",
    "    fn(params=copy_params, mode=PREDICT)\n",
    "\n",
    "\n",
    "\n",
    "preproc_dec_test()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-06-22 20:19:23.338 | WARNING  | m3tl.base_params:prepare_dir:361 - bert_config not exists. will load model from huggingface checkpoint.\n"
     ]
    }
   ],
   "source": [
    "# hide\n",
    "# predict multiprocessing test\n",
    "params.num_cpus = 2\n",
    "\n",
    "\n",
    "@preprocessing_fn\n",
    "def wrapper(params, mode):\n",
    "    for t in [{'text': 'this is a toy input',\n",
    "               'image': np.random.uniform(size=(16))} for _ in range(10)]:\n",
    "        yield t\n",
    "\n",
    "preproc_dec_test(wrapper, run_train=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Multi-modal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@preprocessing_fn\n",
    "def toy_cls(params: Params, mode: str) -> Tuple[list, list]:\n",
    "    \"Simple example to demonstrate multi-modal tuple of list return\"\n",
    "    if mode == m3tl.TRAIN:\n",
    "        toy_input = [{'text': 'this is a toy input',\n",
    "                      'image': np.random.uniform(size=(16))} for _ in range(10)]\n",
    "        toy_target = ['a' for _ in range(10)]\n",
    "    else:\n",
    "        toy_input = [{'text': 'this is a toy input for test',\n",
    "                      'image': np.random.uniform(size=(16))} for _ in range(10)]\n",
    "        toy_target = ['a' for _ in range(10)]\n",
    "\n",
    "    return toy_input, toy_target\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-06-22 20:19:24.038 | WARNING  | m3tl.base_params:prepare_dir:361 - bert_config not exists. will load model from huggingface checkpoint.\n"
     ]
    }
   ],
   "source": [
    "# hide\n",
    "preproc_dec_test()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### A, B Token Multi-modal\n",
    "\n",
    "TODO: Implement this. Not working yet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hide\n",
    "@preprocessing_fn\n",
    "def toy_cls(params: Params, mode: str) -> Tuple[list, list]:\n",
    "    \"Simple example to demonstrate A, B token multi-modal tuple of list return\"\n",
    "    if mode == m3tl.TRAIN:\n",
    "        toy_input = [\n",
    "            {\n",
    "                'a': {\n",
    "                    'text': 'this is a toy input',\n",
    "                    'image': np.random.uniform(size=(16))\n",
    "                },\n",
    "                'b': {\n",
    "                    'text': 'this is a toy input',\n",
    "                    'image': np.random.uniform(size=(16))\n",
    "                }\n",
    "            } for _ in range(10)]\n",
    "        toy_target = ['a' for _ in range(10)]\n",
    "    else:\n",
    "        toy_input = [\n",
    "            {\n",
    "                'a': {\n",
    "                    'text': 'this is a toy input for test',\n",
    "                    'image': np.random.uniform(size=(16))\n",
    "                },\n",
    "                'b': {\n",
    "                    'text': 'this is a toy input for test',\n",
    "                    'image': np.random.uniform(size=(16))\n",
    "                }\n",
    "            } for _ in range(10)]\n",
    "        toy_target = ['a' for _ in range(10)]\n",
    "\n",
    "    return toy_input, toy_target\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # hide\n",
    "# params.register_problem(problem_name='toy_cls', problem_type='cls', processing_fn=toy_cls)\n",
    "# assert (10, 1)==toy_cls(params=params, mode=m3tl.TRAIN, get_data_num=True, write_tfrecord=False)\n",
    "\n",
    "# shutil.rmtree(os.path.join(params.tmp_file_dir, 'toy_cls'))\n",
    "# toy_cls(params=params, mode=m3tl.TRAIN, get_data_num=False, write_tfrecord=True)\n",
    "# assert os.path.exists(os.path.join(params.tmp_file_dir, 'toy_cls', 'train_feature_desc.json'))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generator of Tuple\n",
    "\n",
    "#### Single Modal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@preprocessing_fn\n",
    "def toy_cls(params: Params, mode: str) -> Tuple[list, list]:\n",
    "    \"Simple example to demonstrate singe modal tuple of list return\"\n",
    "    if mode == m3tl.TRAIN:\n",
    "        toy_input = ['this is a toy input' for _ in range(10)]\n",
    "        toy_target = ['a' for _ in range(10)]\n",
    "    else:\n",
    "        toy_input = ['this is a toy input for test' for _ in range(10)]\n",
    "        toy_target = ['a' for _ in range(10)]\n",
    "    for i, t in zip(toy_input, toy_target):\n",
    "        yield i, t\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-06-22 20:19:27.539 | WARNING  | m3tl.base_params:prepare_dir:361 - bert_config not exists. will load model from huggingface checkpoint.\n"
     ]
    }
   ],
   "source": [
    "# hide\n",
    "preproc_dec_test()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Multi-modal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@preprocessing_fn\n",
    "def toy_cls(params: Params, mode: str) -> Tuple[list, list]:\n",
    "    \"Simple example to demonstrate multi-modal tuple of list return\"\n",
    "    if mode == m3tl.TRAIN:\n",
    "        toy_input = [{'text': 'this is a toy input',\n",
    "                      'image': np.random.uniform(size=(16))} for _ in range(10)]\n",
    "        toy_target = ['a' for _ in range(10)]\n",
    "    else:\n",
    "        toy_input = [{'text': 'this is a toy input for test',\n",
    "                      'image': np.random.uniform(size=(16))} for _ in range(10)]\n",
    "        toy_target = ['a' for _ in range(10)]\n",
    "    for i, t in zip(toy_input, toy_target):\n",
    "        yield i, t\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-06-22 20:19:28.204 | WARNING  | m3tl.base_params:prepare_dir:361 - bert_config not exists. will load model from huggingface checkpoint.\n"
     ]
    }
   ],
   "source": [
    "# hide\n",
    "preproc_dec_test()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pyspark dataframe\n",
    "\n",
    "#### single modal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-06-22 20:19:31.418 | WARNING  | m3tl.base_params:prepare_dir:361 - bert_config not exists. will load model from huggingface checkpoint.\n"
     ]
    }
   ],
   "source": [
    "# hide\n",
    "import pandas as pd\n",
    "from pyspark_crud.core import set_globals, get_globals\n",
    "\n",
    "set_globals()\n",
    "_, _, sc, _ = get_globals()\n",
    "\n",
    "\n",
    "@preprocessing_fn\n",
    "def toy_cls(params: Params, mode: str) -> RDD:\n",
    "    if mode == m3tl.TRAIN:\n",
    "        d = {\n",
    "            'inputs': ['this is a toy input' for _ in range(10)],\n",
    "            'labels': ['a' for _ in range(10)]\n",
    "        }\n",
    "    else:\n",
    "        d = {\n",
    "            'inputs': ['this is a toy input for test' for _ in range(10)],\n",
    "            'labels': ['a' for _ in range(10)]\n",
    "        }\n",
    "    # transform d to records shape\n",
    "    d = pd.DataFrame(d).to_dict('records')\n",
    "    rdd = sc.parallelize(d)\n",
    "    return rdd\n",
    "\n",
    "\n",
    "params.pyspark_output_path = tempfile.mkdtemp()\n",
    "preproc_dec_test()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### multimodal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-06-22 20:19:31.948 | WARNING  | m3tl.base_params:prepare_dir:361 - bert_config not exists. will load model from huggingface checkpoint.\n"
     ]
    }
   ],
   "source": [
    "@preprocessing_fn\n",
    "def toy_cls(params: Params, mode: str) -> RDD:\n",
    "    get_or_make_label_encoder(params=params, problem='toy_cls', label_list=['a'], mode=mode)\n",
    "    if mode == m3tl.TRAIN:\n",
    "        d = {\n",
    "            'inputs_text': ['this is a toy input' for _ in range(10)],\n",
    "            'inputs_image': [np.random.uniform(size=(16)).tolist() for _ in range(10)],\n",
    "            'labels': ['a' for _ in range(10)]\n",
    "        }\n",
    "    else:\n",
    "        d = {\n",
    "            'inputs_text': ['this is a toy input test' for _ in range(10)],\n",
    "            'inputs_image': [np.random.uniform(size=(16)).tolist() for _ in range(10)],\n",
    "            'labels': ['a' for _ in range(10)]\n",
    "        }\n",
    "    d = pd.DataFrame(d).to_dict('records')\n",
    "    rdd = sc.parallelize(d)\n",
    "    return rdd\n",
    "\n",
    "\n",
    "preproc_dec_test()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.10 64-bit ('base': conda)",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
